# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from .my_modules import MyNetwork

__all__ = [
	'make_divisible', 'build_activation', 'ShuffleLayer', 'MyGlobalAvgPool2d', 'Hswish', 'Hsigmoid', 'SEModule',
	'MultiHeadCrossEntropyLoss'
]


def make_divisible(v, divisor, min_val=None):
	"""
	This function is taken from the original tf repo.
	It ensures that all layers have a channel number that is divisible by 8
	It can be seen here:
	https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
	:param v:
	:param divisor:
	:param min_val:
	:return:
	"""
	if min_val is None:
		min_val = divisor
	new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
	# Make sure that rounding down does not reduce the channel count by more than 10%.
	if new_v < 0.9 * v:
		new_v += divisor
	return new_v
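
# Illustrative examples (assumptions, not from the original repo):
#   make_divisible(30, 8) -> 32   # rounds to the nearest multiple of 8
#   make_divisible(10, 8) -> 16   # 8 would fall more than 10% below 10, so bump up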


def build_activation(act_func, inplace=True):
	if act_func == 'relu':
		return nn.ReLU(inplace=inplace)
	elif act_func == 'relu6':
		return nn.ReLU6(inplace=inplace)
	elif act_func == 'tanh':
		return nn.Tanh()
	elif act_func == 'sigmoid':
		return nn.Sigmoid()
	elif act_func == 'h_swish':
		return Hswish(inplace=inplace)
	elif act_func == 'h_sigmoid':
		return Hsigmoid(inplace=inplace)
	elif act_func is None or act_func == 'none':
		return None
	else:
		raise ValueError('unsupported act_func: %s' % act_func)
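
# Usage sketch (illustrative): build an activation by name; `None`/'none'
# yields no activation, which callers treat as the identity:
#   act = build_activation('h_swish')   # -> Hswish(inplace=True)
#   act = build_activation('none')      # -> None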


class ShuffleLayer(nn.Module):

	def __init__(self, groups):
		super(ShuffleLayer, self).__init__()
		self.groups = groups

	def forward(self, x):
		batch_size, num_channels, height, width = x.size()
		channels_per_group = num_channels // self.groups
		# reshape: (N, C, H, W) -> (N, groups, C // groups, H, W)
		x = x.view(batch_size, self.groups, channels_per_group, height, width)
		# transpose the group and channel axes so channels interleave across groups
		x = torch.transpose(x, 1, 2).contiguous()
		# flatten back to (N, C, H, W)
		x = x.view(batch_size, -1, height, width)
		return x

	def __repr__(self):
		return 'ShuffleLayer(groups=%d)' % self.groups
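
# ShuffleLayer implements the channel shuffle from ShuffleNet (Zhang et al., 2018),
# interleaving channels across groups so information flows between grouped
# convolutions; num_channels must be divisible by `groups`. Illustrative check:
#   x = torch.arange(4.).view(1, 4, 1, 1)
#   ShuffleLayer(groups=2)(x).flatten()  ->  tensor([0., 2., 1., 3.])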


class MyGlobalAvgPool2d(nn.Module):

	def __init__(self, keep_dim=True):
		super(MyGlobalAvgPool2d, self).__init__()
		self.keep_dim = keep_dim

	def forward(self, x):
		return x.mean(3, keepdim=self.keep_dim).mean(2, keepdim=self.keep_dim)

	def __repr__(self):
		return 'MyGlobalAvgPool2d(keep_dim=%s)' % self.keep_dim
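
# Behaves like F.adaptive_avg_pool2d(x, 1) when keep_dim=True. Illustrative shapes:
#   x = torch.randn(2, 8, 7, 7)
#   MyGlobalAvgPool2d(keep_dim=True)(x).shape   -> torch.Size([2, 8, 1, 1])
#   MyGlobalAvgPool2d(keep_dim=False)(x).shape  -> torch.Size([2, 8])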


class Hswish(nn.Module):

	def __init__(self, inplace=True):
		super(Hswish, self).__init__()
		self.inplace = inplace

	def forward(self, x):
		return x * F.relu6(x + 3., inplace=self.inplace) / 6.

	def __repr__(self):
		return 'Hswish()'
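
# Hswish is the hard swish from MobileNetV3 (Howard et al., 2019):
#   h_swish(x) = x * relu6(x + 3) / 6
# Quick check (illustrative): Hswish()(torch.tensor([-4., 0., 4.]))
#   -> tensor([-0., 0., 4.])   # relu6 clamps x + 3 to [0, 6]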


class Hsigmoid(nn.Module):

	def __init__(self, inplace=True):
		super(Hsigmoid, self).__init__()
		self.inplace = inplace

	def forward(self, x):
		return F.relu6(x + 3., inplace=self.inplace) / 6.

	def __repr__(self):
		return 'Hsigmoid()'
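
# Hsigmoid is the hard sigmoid from MobileNetV3, a piecewise-linear
# approximation of torch.sigmoid: h_sigmoid(x) = relu6(x + 3) / 6.
# Quick check (illustrative):
#   Hsigmoid()(torch.tensor([-3., 0., 3.]))  ->  tensor([0.0000, 0.5000, 1.0000])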


class SEModule(nn.Module):
	REDUCTION = 4

	def __init__(self, channel, reduction=None):
		super(SEModule, self).__init__()

		self.channel = channel
		self.reduction = SEModule.REDUCTION if reduction is None else reduction

		num_mid = make_divisible(self.channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE)

		self.fc = nn.Sequential(OrderedDict([
			('reduce', nn.Conv2d(self.channel, num_mid, 1, 1, 0, bias=True)),
			('relu', nn.ReLU(inplace=True)),
			('expand', nn.Conv2d(num_mid, self.channel, 1, 1, 0, bias=True)),
			('h_sigmoid', Hsigmoid(inplace=True)),
		]))

	def forward(self, x):
		# squeeze: global spatial average pooling -> (N, C, 1, 1)
		y = x.mean(3, keepdim=True).mean(2, keepdim=True)
		# excitation: bottleneck MLP producing per-channel gates in [0, 1]
		y = self.fc(y)
		# rescale the input channel-wise
		return x * y

	def __repr__(self):
		return 'SE(channel=%d, reduction=%d)' % (self.channel, self.reduction)
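
# SEModule is a Squeeze-and-Excitation block (Hu et al., 2018) with a hard-sigmoid
# gate. Usage sketch (assumes MyNetwork.CHANNEL_DIVISIBLE == 8, as in the OFA repo):
#   se = SEModule(64)                   # bottleneck width: make_divisible(64 // 4, 8) = 16
#   y = se(torch.randn(2, 64, 14, 14))  # -> same shape, channels rescaled by learned gates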


class MultiHeadCrossEntropyLoss(nn.Module):

	def forward(self, outputs, targets):
		# outputs: (batch_size, num_heads, num_classes); targets: (batch_size, num_heads)
		assert outputs.dim() == 3, outputs
		assert targets.dim() == 2, targets

		assert outputs.size(1) == targets.size(1), (outputs, targets)
		num_heads = targets.size(1)

		loss = 0
		for k in range(num_heads):
			loss += F.cross_entropy(outputs[:, k, :], targets[:, k]) / num_heads
		return loss
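
# MultiHeadCrossEntropyLoss averages a standard cross-entropy loss over several
# classifier heads. Usage sketch (illustrative): 3 heads, 10 classes each
#   outputs = torch.randn(8, 3, 10)           # (batch_size, num_heads, num_classes)
#   targets = torch.randint(0, 10, (8, 3))    # (batch_size, num_heads)
#   loss = MultiHeadCrossEntropyLoss()(outputs, targets)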
